import argparse, os, json, numpy as np, pandas as pd
from .config import load_study_config
from .io_inputs import read_prestacked, parse_RG_bin_mid
from .plateau import select_flat_window

# Column order of outputs/lensing_plateau.csv. Declared explicitly so the table
# keeps its schema (and stays filterable) even when no stack produced a row.
_PLATEAU_COLUMNS = (
    "stack_id", "R_G_bin", "Mstar_bin",
    "A_theta", "A_theta_CI_low", "A_theta_CI_high",
    "rmse_flat", "R2_flat", "n_lenses", "claimable",
)


def _bootstrap_median_ci(values, n_boot, seed):
    """Return the (16th, 84th) percentiles of bootstrap medians of *values*.

    Each of the *n_boot* resamples draws ``len(values)`` entries with
    replacement and records their median. A fresh generator is built from
    *seed* on every call, so repeated calls with equal-length inputs reuse the
    same resampling indices.
    """
    rng = np.random.default_rng(seed)
    n = len(values)
    medians = [
        float(np.median(values[rng.integers(0, n, size=n)]))
        for _ in range(n_boot)
    ]
    lo, hi = np.percentile(medians, [16, 84])
    return float(lo), float(hi)


def _measure_stack(sid, sub, cfg):
    """Build the plateau-table row for one stack.

    Returns ``(row, window, flat)``. ``window`` and ``flat`` are ``None`` when
    ``select_flat_window`` reports that no acceptable window was found; in that
    case the row carries NaN measurements and ``claimable=False``.
    """
    b = sub["b"].values
    amplitude = sub["gamma_t"].values * b
    # Column check first so cfg.use_geo_in_amplitude is only read when relevant.
    if "geo_norm" in sub.columns and cfg.use_geo_in_amplitude:
        amplitude = amplitude * sub["geo_norm"].values

    i0, i1, ok, stats = select_flat_window(
        b, amplitude, cfg.plateau_slope_abs_max, cfg.min_bins, cfg.min_b, cfg.max_b
    )

    row = {
        "stack_id": sid,
        "R_G_bin": sub["R_G_bin"].iloc[0],
        "Mstar_bin": sub["Mstar_bin"].iloc[0],
        "A_theta": np.nan,
        "A_theta_CI_low": np.nan,
        "A_theta_CI_high": np.nan,
        "rmse_flat": np.nan,
        "R2_flat": np.nan,
        "n_lenses": None,  # never populated by this script
        "claimable": False,
    }
    if not ok:
        return row, None, None

    # Bootstrap CI over the bins inside the selected window (i1 is inclusive).
    lo, hi = _bootstrap_median_ci(
        amplitude[i0:i1 + 1], cfg.bootstrap_n, cfg.random_seed
    )
    rmse = float(stats["rmse_flat"])
    r2 = float(stats["R2_flat"])
    row.update({
        "A_theta": float(stats["A_theta"]),
        "A_theta_CI_low": lo,
        "A_theta_CI_high": hi,
        "rmse_flat": rmse,
        "R2_flat": r2,
        "claimable": True,
    })
    return row, {"i0": int(i0), "i1": int(i1)}, {"rmse_flat": rmse, "R2_flat": r2}


def _size_regression(out, cfg):
    """Fit A_theta against the R_G bin midpoint within each Mstar bin.

    Only rows flagged ``claimable`` are used. A bin is skipped when fewer than
    two rows have finite midpoint and amplitude. The slope interval is the
    16th-84th percentile range of slopes refitted on resampled rows; the
    generator is re-created from ``cfg.random_seed`` for every bin.

    Returns a dict keyed by ``str(Mstar_bin)``.
    """
    reg = {}
    claimable = out.loc[out["claimable"].astype(bool)]
    for mstar, grp in claimable.groupby("Mstar_bin"):
        size = grp["R_G_bin"].apply(parse_RG_bin_mid).values.astype(float)
        amp = grp["A_theta"].values.astype(float)
        finite = np.isfinite(size) & np.isfinite(amp)
        n = int(np.sum(finite))
        if n < 2:
            continue

        design = np.vstack([size[finite], np.ones(n)]).T
        target = amp[finite]
        beta, _, _, _ = np.linalg.lstsq(design, target, rcond=None)

        rng = np.random.default_rng(cfg.random_seed)
        slopes = []
        for _ in range(cfg.bootstrap_n):
            idx = rng.integers(0, n, size=n)
            beta_b, _, _, _ = np.linalg.lstsq(design[idx], target[idx], rcond=None)
            slopes.append(float(beta_b[0]))
        lo, hi = np.percentile(slopes, [16, 84])

        # groupby keys can be NumPy scalars, which json.dump rejects as object
        # keys; str() matches what JSON would emit for plain str/int/float keys.
        reg[str(mstar)] = {
            "slope_Atheta_vs_RG": float(beta[0]),
            "slope_CI_16_84": [float(lo), float(hi)],
            "n_stacks": n,
        }
    return reg


def _write_json(path, payload):
    """Write *payload* to *path* as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main(argv=None):
    """Run the T3 analysis: plateau measurement per stack, then size regression.

    Args:
        argv: Optional argument list handed to ``argparse``; ``None`` (the
            default) makes argparse read the process command line.

    Side effects:
        Creates ``outputs/`` and ``figures/`` if missing and writes
        ``outputs/lensing_plateau.csv``, ``outputs/windows.json``,
        ``outputs/flatness.json`` and ``outputs/size_regression.json``.
    """
    ap = argparse.ArgumentParser(description="Run T3 (lensing plateau & size–amplitude).")
    ap.add_argument("--config", default="config/study.yaml")
    args = ap.parse_args(argv)
    cfg = load_study_config(args.config)

    os.makedirs("outputs", exist_ok=True)
    os.makedirs("figures", exist_ok=True)

    # The metadata table returned alongside the stacks is not used here.
    df, _meta = read_prestacked(cfg.prestack_csv, cfg.prestack_meta_csv or None)

    rows = []
    windows = {}
    flatness = {}
    for sid, sub in df.groupby("stack_id"):
        row, window, flat = _measure_stack(sid, sub, cfg)
        rows.append(row)
        if window is not None:
            # str() keeps NumPy-typed stack ids usable as JSON object keys.
            windows[str(sid)] = window
            flatness[str(sid)] = flat

    # Explicit columns: with zero stacks the frame would otherwise have no
    # "claimable" column and the regression step would raise KeyError.
    out = pd.DataFrame(rows, columns=list(_PLATEAU_COLUMNS))
    out.to_csv("outputs/lensing_plateau.csv", index=False)
    _write_json("outputs/windows.json", windows)
    _write_json("outputs/flatness.json", flatness)

    # size–amplitude regression at fixed Mstar bins
    _write_json("outputs/size_regression.json", _size_regression(out, cfg))

# Entry guard: run the analysis only when this file is the program's entry
# module, not when it is imported. The relative imports at the top of the file
# mean it has to be launched as part of its package (e.g. with ``python -m``).
if __name__=="__main__":
    main()
